# IMQ_analysis.R --------------------------------------------------------

# Author: vgouirand 
# Description: differential expression analysis of bulk-RNAseq data
# Input: here::here("input", "outs") 
# Output: here::here("output", "IMQ_counts.csv")
# Date: 2023_07_25

# Library Import ----------------------------------------------------------

library(conflicted)
library(AnnotationDbi)
library(biomaRt)
library(edgeR)
library(fgsea)
library(GenomicFeatures)
library(here)
library(limma)
library(msigdbr)
library(readxl)
library(tidyverse)
library(tximeta)
library(tximport)
library(VennDiagram)

# Print console output at a width of 80 characters.
options(width = 80)

# Bare `filter()` calls in this script refer to dplyr's row filter.
filter <- dplyr::filter

# Metadata  ---------------------------------------------------------------
# Import metadata. 

# Read the sample sheet and derive `ID`, the experimental group used for
# modelling: "<Cell_Type>_<Group>" as a factor with a fixed level order.
IMQ_metadata <-
  here("input", "IMQ_annotations.xlsx") %>%
  read_xlsx() %>%
  mutate(
    ID = factor(
      x = paste(Cell_Type, Group, sep = "_"),
      levels = c("Teff_LAYN_flox", "Teff_CTRL", "Treg_LAYN_flox", "Treg_CTRL")
    )
  )

# Gene References ---------------------------------------------------------

# Build a transcript database from the Ensembl GRCm39 (release 110) GTF.
GRCm39 <- makeTxDbFromGFF(
  file = here("input", "Mus_musculus.GRCm39.110.gtf.gz"),
  organism = "Mus musculus",
  format = "gtf"
)

# tx2gene: one row per transcript name with its gene ID.
ttg <- AnnotationDbi::select(
  x = GRCm39,
  keys = keys(x = GRCm39, keytype = "TXNAME"),
  columns = "GENEID",
  keytype = "TXNAME"
)

# biomaRt
# 20230228
# biomaRt::listEnsembl()
#          biomart                version
# 1         genes      Ensembl Genes 110
# 2 mouse_strains      Mouse strains 110
# 3          snps  Ensembl Variation 110
# 4    regulation Ensembl Regulation 110

# The mart object was created once with the commented code below and cached
# as an .rds file; the script reloads that cached object.
# ensembl <- useEnsembl(
#   biomart = 'genes',
#   dataset = 'mmusculus_gene_ensembl',
#   version = 110
# )
# write_rds(x = ensembl, file = here("input", "mmusculus_gene_ensembl_110_mart.rds"))

ensembl <- read_rds(here("input", "mmusculus_gene_ensembl_110_mart.rds"))

# Ensembl gene ID -> MGI symbol lookup table.
key <- getBM(
  mart = ensembl,
  attributes = c("ensembl_gene_id", "mgi_symbol")
)


# Import Abundances -------------------------------------------------------

# Collect every kallisto `abundance.tsv` below input/outs. `pattern` is
# matched against file names, and list.files() returns the paths sorted.
IMQ_abundances <- list.files(
  path = here("input", "outs"),
  pattern = "^abundance\\.tsv$",
  full.names = TRUE,
  recursive = TRUE
)

# Sample names are attached by position, so a count mismatch would silently
# mislabel or NA-pad the files; stop instead.
stopifnot(
  "Number of abundance.tsv files must equal the number of metadata rows" =
    length(IMQ_abundances) == nrow(IMQ_metadata)
)

# NOTE(review): assumes the sorted file order matches the row order of
# IMQ_metadata -- confirm against the directory layout.
IMQ_abundances <- setNames(object = IMQ_abundances, nm = IMQ_metadata$samples)



# TxImport ----------------------------------------------------------------

# Summarise kallisto transcript-level estimates to gene level.
txi <- tximport(
  files = IMQ_abundances,
  type = "kallisto",
  txIn = TRUE, # input is transcript-level
  txOut = FALSE, # output is gene-level (summarised via tx2gene)
  tx2gene = ttg,
  ignoreTxVersion = TRUE, # match transcript IDs without the ".version" suffix
  countsFromAbundance = "lengthScaledTPM"
)


# Limma -------------------------------------------------------------------


# Gene-level counts annotated with MGI symbols; a gene with several symbols
# in `key` appears once per symbol.
mgi_counts <- as.data.frame(txi$counts) %>%
  rownames_to_column(var = "genes") %>%
  left_join(y = key, by = c("genes" = "ensembl_gene_id")) %>%
  relocate(mgi_symbol, .after = genes)

# First row per Ensembl gene, so rows line up one-to-one (and in the same
# order) with the rows of txi$counts.
unique_mgi_counts <- distinct(mgi_counts, genes, .keep_all = TRUE)

# edgeR container: counts plus the sample sheet, grouped by ID.
y <- DGEList(
  counts = txi$counts,
  group = IMQ_metadata$ID,
  samples = IMQ_metadata
  # genes = unique_mgi_counts$mgi_symbol # match genes in counts to gene order in key
)


# Foxp3 counts per sample (the selected row is plotted through its Ensembl ID
# column, ENSMUSG00000039521).
# The row mask relies on unique_mgi_counts being row-aligned with y$counts.
# `%in%` is used instead of `==` so genes without an MGI symbol (NA) give
# FALSE rather than NA rows in the subset.
y$counts[unique_mgi_counts$mgi_symbol %in% "Foxp3", , drop = FALSE] %>%
  t() %>%
  # The matrix row names are sample names, so store them as `samples` and
  # join on the `samples` column of the annotation (`Group` holds genotype).
  as_tibble(rownames = "samples") %>%
  full_join(y = y$samples, by = "samples") %>%
  ggplot(aes(x = samples, y = ENSMUSG00000039521)) +
  geom_col(aes(fill = ID)) +
  labs(
    y = "Foxp3",
    x = NULL
  ) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.5))) +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))


# total counts
# Library size per sample: column sums of the count matrix, joined to the
# sample annotation and drawn as one bar per sample, filled by ID.
colSums(y$counts) %>%
  as_tibble(rownames = "samples") %>%
  full_join(y = y$samples, by = "samples") %>%
  ggplot(mapping = aes(x = samples, y = value, fill = ID)) +
  geom_col() +
  scale_y_continuous(expand = expansion(mult = c(0, 0.5))) +
  labs(x = NULL, y = "Counts") +
  theme_classic() +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

# MDS plot of the samples, coloured by experimental group: one hue per level
# of ID, looked up through the factor's integer codes.
plotMDS(
  x = y,
  col = scales::hue_pal()(nlevels(y$samples$ID))[as.integer(y$samples$ID)],
  cex = 0.75
)


# remove Sk_Teff_WT_5
# Drop the sample from the metadata first (exact name match), then select the
# count columns by the remaining sample names so that counts, sample labels
# and groups are guaranteed to describe the same samples in the same order.
IMQ_metadata <- IMQ_metadata[IMQ_metadata$samples != "Sk_Teff_WT_5", ]

y <- DGEList(
  counts = txi$counts[, IMQ_metadata$samples, drop = FALSE],
  samples = IMQ_metadata$samples,
  group = IMQ_metadata$ID
)


# No-intercept design: one column per level of ID.
design <- model.matrix(object = ~ 0 + ID, data = IMQ_metadata)

# Keep genes with enough expression for this design, recompute library sizes
# from the retained genes, then compute normalisation factors.
keep <- filterByExpr(y, design = design)
y <- calcNormFactors(y[keep, , keep.lib.sizes = FALSE])

# voom transformation; plot = TRUE draws the mean-variance trend.
v <- voom(counts = y, design = design, plot = TRUE)
# v2 <- voomWithQualityWeights(counts = y, design = design, plot = TRUE)

# Pairwise comparisons between the ID groups; the column order defines the
# coefficient index used when extracting results.
contrasts <- makeContrasts(
  Treg_LAYN_flox_v_Treg_CTRL = IDTreg_LAYN_flox - IDTreg_CTRL, # coef = 1
  Teff_LAYN_flox_v_Teff_CTRL = IDTeff_LAYN_flox - IDTeff_CTRL, # coef = 2
  Teff_LAYN_flox_v_Treg_LAYN_flox = IDTeff_LAYN_flox - IDTreg_LAYN_flox, # coef = 3
  Teff_CTRL_v_Treg_CTRL = IDTeff_CTRL - IDTreg_CTRL, # coef = 4
  levels = colnames(design)
)

# Linear model per gene on the voom output. The result is stored under the
# same name as the limma function; the call still resolves to the function
# because R skips non-function bindings when looking up a call.
lmFit <- lmFit(v, design = design)
# weights = v$weights

# Re-express the group-mean fit in terms of the contrasts above.
contrastsFit <- contrasts.fit(lmFit, contrasts = contrasts)

# Empirical Bayes moderation.
# NOTE(review): `trend = TRUE` is the limma-trend setting; the fit here comes
# from voom -- confirm that combining the two is intended.
ebayes <- eBayes(contrastsFit, trend = TRUE)

# One full (unsorted, BH-adjusted) result table per contrast, stacked into a
# single data frame. The `contrast` column holds the contrast name and is
# also split at "_v_" into `group1` / `group2`.
degs <- lapply(
  X = setNames(seq_len(ncol(contrasts)), colnames(contrasts)),
  FUN = function(coef_index) {
    topTable(
      fit = ebayes,
      coef = coef_index,
      number = Inf,
      sort.by = "none",
      adjust.method = "BH"
    ) %>%
      rownames_to_column(var = "Ensembl")
  }
) %>%
  bind_rows(.id = "contrast") %>%
  separate(
    col = contrast,
    into = c("group1", "group2"),
    sep = "_v_",
    remove = FALSE
  )

# Attach MGI symbols to the results; an Ensembl ID with several symbols in
# `key` yields one row per symbol (multiple = "all").
degs_mgi <- degs %>%
  left_join(
    y = key,
    by = join_by(Ensembl == ensembl_gene_id),
    multiple = "all"
  )


# Volcano plot per contrast with Layn highlighted and labelled.
# `%in%` (rather than `==`) keeps genes without an MGI symbol: with `==` the
# NA comparison produced an NA colour/label for those rows.
degs_mgi %>%
  mutate(is_layn = mgi_symbol %in% "Layn") %>%
  ggplot(aes(x = logFC, y = -log10(adj.P.Val))) +
  geom_point(aes(color = ifelse(test = is_layn, "red", "black"))) +
  ggrepel::geom_text_repel(aes(label = ifelse(test = is_layn, mgi_symbol, ""))) +
  scale_color_identity() +
  facet_wrap(~ contrast, nrow = 1) +
  theme_classic() +
  theme(strip.background = element_blank(), strip.text.x = element_text(face = "bold")) +
  coord_cartesian(clip = "off")

# Number of rows with BH-adjusted p < 0.05 in each contrast.
degs_mgi %>% dplyr::filter(adj.P.Val < 0.05) %>% dplyr::count(contrast)

# View() opens a data viewer, so only call it in an interactive session;
# this keeps the script runnable non-interactively.
if (interactive()) {
  degs_mgi %>% dplyr::filter(adj.P.Val < 0.05) %>% View()
}

# Bar chart of the number of rows with BH-adjusted p < 0.05 per contrast,
# with the count printed above each bar.
degs_mgi %>%
  dplyr::filter(adj.P.Val < 0.05) %>%
  dplyr::count(contrast) %>%
  ggplot(aes(x = contrast, y = n)) +
  geom_col(aes(fill = contrast)) +
  # `.pt` is exported by ggplot2, so `::` is enough; dividing by it converts
  # a size in points to the unit geom_text() expects.
  geom_text(aes(label = n), size = 8 / ggplot2::.pt, nudge_y = 100) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) +
  guides(fill = guide_legend(nrow = 1)) +
  theme_classic() +
  labs(
    x = NULL,
    y = "No. DEGs"
  ) +
  theme(legend.position = "none", axis.text.x = element_text(angle = 35, hjust = 1))

# Rows of the Treg LAYN_flox vs CTRL contrast below the adjusted-p cutoff,
# kept only when the contrast group contains exactly two such rows.
# A stray `%>% View()` used to end the pipe after group_by(), leaving the
# following mutate() without input data; the chain is now continuous.
# NOTE(review): the 0.055 cutoff differs from the 0.05 used elsewhere in this
# script -- confirm it is intentional.
degs_mgi %>%
  dplyr::filter(contrast %in% c("Treg_LAYN_flox_v_Treg_CTRL"), adj.P.Val < 0.055) %>%
  group_by(contrast) %>%
  mutate(name_count = n()) %>%
  ungroup() %>%
  filter(name_count == 2) %>%
  dplyr::select(-name_count)


# Save the annotated results for all contrasts.
degs_mgi %>%
  write_csv(file = here("output", "degs_mgi.csv"))

# fgsea -------------------------------------------------------------------

# All mouse MSigDB gene sets, as a named list: gene set name -> gene symbols.
# split() is called on the columns directly. Piping the data frame into
# `split(x = .$a, f = .$b)` would also pass the data frame itself as an extra
# positional argument, because magrittr still inserts it when `.` only
# appears inside other expressions.
msigdb <- msigdbr::msigdbr(species = "Mus musculus")

msigdb_split <- split(x = msigdb$gene_symbol, f = msigdb$gs_name)

# Gene ontology: C5 gene sets whose sub-category starts with "GO:".

go <- msigdbr::msigdbr(species = "Mus musculus", category = "C5")
go_split <- go %>%
  dplyr::filter(grepl("^GO\\:", gs_subcat)) %>%
  with(split(x = gene_symbol, f = gs_name))


# Ranked statistics for the Treg LAYN_flox vs CTRL contrast: a named numeric
# vector (MGI symbol -> logFC) in decreasing logFC order.
# Rows without a symbol are dropped (filter() also drops NA comparisons).
# Gene names in the vector must be unique for fgsea, so when a symbol occurs
# more than once only its highest-logFC row (the first after sorting) is kept.
Treg_KO_v_Treg_CTRL <- degs_mgi %>%
  dplyr::filter(contrast == "Treg_LAYN_flox_v_Treg_CTRL", mgi_symbol != "") %>%
  dplyr::arrange(-logFC) %>%
  dplyr::distinct(mgi_symbol, .keep_all = TRUE) %>%
  dplyr::select(mgi_symbol, logFC) %>%
  tibble::deframe()

# One ranked statistics vector (MGI symbol -> logFC, decreasing) per contrast,
# returned as a list named by contrast.
# Duplicate symbols were previously resolved with slice_sample() before any
# seed was set, so the ranks changed between runs. The choice is now
# deterministic: after sorting by decreasing logFC, the first (highest) row
# of each symbol is kept. The number of genes per vector is unchanged.
rank_list <- lapply(
  X = setNames(nm = unique(degs_mgi$contrast)),
  FUN = function(contrast_name) {
    message("Generating gene set ", contrast_name)
    degs_mgi %>%
      dplyr::filter(contrast == contrast_name, mgi_symbol != "") %>%
      dplyr::select(mgi_symbol, logFC) %>%
      dplyr::arrange(-logFC) %>%
      dplyr::distinct(mgi_symbol, .keep_all = TRUE) %>%
      tibble::deframe()
  }
)

# Sanity checks: number of ranked genes in the first contrast, and the
# contrast names carried by the list.
rank_list[[1]] %>% length()
# 23182

names(rank_list)

# fgsea_go
# Run multilevel fgsea against the GO gene sets for every contrast, stack the
# per-contrast tables (contrast name in `contrast`) and write them to disk.
set.seed(42L)
fgsea_go <- Map(
  f = function(stats, contrast_name) {
    message("Running fgsea on ", contrast_name,  " contrast.")
    fgsea::fgseaMultilevel(
      pathways = go_split,
      stats = stats,
      eps = 0,
      minSize = 15,
      scoreType = "std",
      nproc = 8
    )
  },
  rank_list,
  names(rank_list)
)

# `leadingEdge` is a list column; collapse each element to one comma-separated
# string so the table can be written as CSV. vapply() guarantees a character
# vector even when the result has no rows.
fgsea_go <- fgsea_go %>%
  bind_rows(.id = "contrast") %>%
  mutate(leadingEdge = vapply(leadingEdge, toString, FUN.VALUE = character(1)))

write_csv(x = fgsea_go, file = here("output", "fgsea_go.csv"))

# Multilevel fgsea on the GO gene sets, one result table per contrast.
# The contrast name is now iterated alongside each ranked vector; previously
# the whole names vector was passed to every call, so each progress message
# listed all contrasts at once.
# NOTE(review): this uses the same pathways and settings as the fgsea_go run,
# and `fgsea_res` is reassigned by the fgseaSimple() call that follows --
# confirm this run is still needed.
fgsea_res <- Map(
  f = function(x, y) {
    message("Running fgsea on ", y,  " contrast.")
    fgsea::fgseaMultilevel(
      pathways = go_split,
      stats = x,
      # sampleSize = 101,
      eps = 0,
      # nperm = 1E6, # ???
      minSize = 15, # ???
      # maxSize = 500, # ???
      scoreType = "std",
      nproc = 8
    )
  },
  rank_list,
  names(rank_list)
)


# Permutation-based fgsea (10000 permutations) on the GO gene sets for the
# Treg LAYN_flox vs CTRL ranking, restricted to sets of 50-500 genes.
fgsea_res <- fgsea::fgseaSimple(
  pathways = go_split,
  stats = Treg_KO_v_Treg_CTRL,
  minSize = 50,
  maxSize = 500,
  nperm = 10000,
  scoreType = "std",
  nproc = 8
)

# Ten lowest-p-value pathways in each direction of enrichment; the negative
# ones are reversed so the table runs from strongest-up to strongest-down.
topPathwaysUp <- fgsea_res[ES > 0][order(pval), head(pathway, n = 10)]
topPathwaysDown <- fgsea_res[ES < 0][order(pval), head(pathway, n = 10)]
topPathways <- c(topPathwaysUp, rev(topPathwaysDown))

fgsea::plotGseaTable(
  pathways = go_split[topPathways],
  stats = Treg_KO_v_Treg_CTRL,
  fgseaRes = fgsea_res,
  gseaParam = 0.5
)